import numpy as np
import os
from Record.file_management import read_obj_dumps, load_from_pickle, save_to_pickle, create_directory
from Buffer.fill_buffer import fill_buffer
from Buffer.all_buffer import AllReplayBuffer
from Vae.record_vae_and_masks_state import load_encodings

def set_batch(buffer, batch):
    """Load ``batch`` into ``buffer`` and reset its internal cursor/size.

    After ``buffer.set_batch(batch)`` the private ``_index`` is reset to 0 and
    ``_size`` is set to ``len(batch)`` so the buffer reports exactly the batch
    it now holds.

    TODO: writing private attributes breaks the TS buffer abstraction boundary;
    this works around the TS buffer's design rather than using a public API.
    """
    buffer.set_batch(batch)
    # Reset the write cursor first, then record how many entries are stored.
    buffer._index, buffer._size = 0, len(batch)

def train_test_indices(args, buffer):
    """Split the positions ``0..len(buffer)-1`` into train and test indices.

    Args:
        args: config namespace; reads ``args.train.train_test_order`` ("random"
            or "time") and ``args.train.train_test_ratio`` (fraction of entries
            assigned to the train split; the count is ``int(len(buffer) * ratio)``).
        buffer: any sized object; only ``len(buffer)`` is used.

    Returns:
        ``(train_indices, test_indices)``. For "random", ``train_indices`` is a
        numpy array sampled without replacement via ``np.random.choice`` and
        ``test_indices`` is an ascending list of the remaining positions. For
        "time", both are lists: the first ``split`` positions, then the rest.

    Raises:
        ValueError: if ``train_test_order`` is not "random" or "time".
    """
    order = args.train.train_test_order
    if order not in ("random", "time"):
        raise ValueError("invalid ordering setting")

    n = len(buffer)
    split = int(n * args.train.train_test_ratio)

    if order == "random":
        # Sampling positions from range(n) without replacement; same RNG draw
        # as passing list(range(n)).
        train_indices = np.random.choice(n, size=split, replace=False)
        # A set gives O(1) membership instead of scanning the array per index.
        chosen = set(train_indices.tolist())
        test_indices = [i for i in range(n) if i not in chosen]
    else:  # "time": keep temporal order, earliest entries go to train
        positions = list(range(n))
        train_indices, test_indices = positions[:split], positions[split:]
    return train_indices, test_indices

def generate_buffers(environment, args, extractor, norm, train=True):
    """Load recorded rollouts into a buffer and optionally split it into train/test buffers.

    Args:
        environment: passed to ``fill_buffer``; ``environment.name`` is used in
            the filename of the optional intermediate pickle.
        args: nested config namespace; reads fields from ``args.train``,
            ``args.image_enc``, ``args.inter`` and ``args.record``.
        extractor: passed to ``fill_buffer``; ``extractor.names`` is the
            fallback outcome variable when no train name is configured.
        norm: passed through to ``fill_buffer``.
        train: when False, skip the split and return the unsplit buffer.

    Returns:
        ``(None, buffer)`` when ``train`` is False. Otherwise
        ``(train_buffer, test_buffer)``: two ``AllReplayBuffer`` instances
        (``stack_num=1``) sized to and filled with the train/test subsets
        chosen by ``train_test_indices``.
    """
    # load rollout data; encodings (if a path is configured) are trimmed to the
    # last num_frames entries.
    # NOTE(review): if num_frames == 0, the slice [-0:] keeps ALL encodings — confirm intended.
    data = read_obj_dumps(args.train.load_rollouts, i=-1, rng=args.train.num_frames, filename=args.train.load_filename)
    encodings = load_encodings(args.train.load_encodings)[- args.train.num_frames:] if len(args.train.load_encodings) > 0 else None

    # the outcome variable is the first configured train name; fall back to all
    # extractor names when none is given. Checking the list length first avoids
    # an IndexError when train_names is empty.
    train_names = args.inter.train_names
    outcome_variable = train_names[0] if len(train_names) > 0 and len(train_names[0]) > 0 else extractor.names
    buffer = fill_buffer(data, environment, args, extractor, norm, encodings=encodings, encoding_length=args.image_enc.encoding_dim, outcome_variable=outcome_variable)
    if not train:
        return None, buffer

    # get indices for train/test, there are various settings for this
    train_indices, test_indices = train_test_indices(args, buffer)

    # optionally persist the full, unsplit buffer (only on the train path)
    if args.record.save_intermediate:
        save_to_pickle(os.path.join(create_directory(args.record.save_intermediate), environment.name + "_mask_rollouts.pkl"), buffer)

    # fill buffers for the train set
    train_buffer = AllReplayBuffer(len(train_indices), stack_num=1)
    set_batch(train_buffer, buffer[train_indices])
    # fill buffers for the test set
    test_buffer = AllReplayBuffer(len(test_indices), stack_num=1)
    set_batch(test_buffer, buffer[test_indices])
    return train_buffer, test_buffer
